import os
import sys
import random
import math
import re
import time
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Project root is the directory the interpreter / notebook was launched from.
# It is put on the module search path so the project-local modules imported
# next (utils, visualize, model, cell) resolve. The membership check prevents
# duplicate sys.path entries when this cell is re-executed.
ROOT_DIR = os.getcwd()
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)
import utils
import visualize
from visualize import display_images
import model as modellib
from model import log
import cell
# `%matplotlib inline` is IPython-only syntax and a SyntaxError in a plain .py
# file. Request the inline backend through the IPython API instead; outside
# IPython/Jupyter `get_ipython` is undefined, so the request is skipped.
try:
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821
except NameError:
    pass  # not running under IPython; keep matplotlib's default backend
# File-system locations used by the rest of the script.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")          # model_dir handed to modellib.MaskRCNN
CELL_DIR = os.path.join(ROOT_DIR, "datasets/cell")  # dataset root passed to load_cell
CELL_WEIGHTS_PATH = "./mask_rcnn_cell.h5"           # weights file, relative to the cwd

# Base cell configuration; its class is extended below for inference.
config = cell.CellConfig()
class InferenceConfig(type(config)):
    """Inference variant of the cell configuration.

    Derives from the class of the `cell.CellConfig` instance created above and
    overrides only the device / per-device image counts.
    """

    # A single device handling a single image. NOTE(review): presumably the base
    # Config derives its batch size from these two values — confirm in `cell`.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1


# Replace the base config with the inference one and print its settings.
config = InferenceConfig()
config.display()
# Device the model graph is built on ("/cpu:0" or "/gpu:0"), and the intended
# Mask R-CNN mode name.
DEVICE, TEST_MODE = "/cpu:0", "inference"
def get_ax(rows=1, cols=1, size=16):
    """Return a fresh grid of matplotlib axes for the visualize helpers.

    Each grid cell is `size` x `size` inches, so the figure measures
    `size * cols` wide by `size * rows` tall. The Figure handle is dropped.
    The result is whatever `plt.subplots` yields as its second element: a
    single Axes for a 1x1 grid, otherwise an array of Axes.
    """
    width, height = size * cols, size * rows
    return plt.subplots(rows, cols, figsize=(width, height))[1]
# Load the "val" split of the cell dataset and finalize it for indexing.
dataset = cell.CellDataset()
dataset.load_cell(CELL_DIR, "val")
dataset.prepare()

n_images = len(dataset.image_ids)
print("Images: {}\nClasses: {}".format(n_images, dataset.class_names))
# Build the Mask R-CNN model on the chosen device in the configured mode.
# TEST_MODE already holds "inference"; use it instead of repeating the literal.
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode=TEST_MODE, model_dir=MODEL_DIR,
                              config=config)

# `tf` is already imported at the top of the file, so it is not re-imported here.
print(tf.__version__)

# Reuse the weights location declared with the other paths rather than
# duplicating the literal. `weights_path` stays as a name for later cells.
weights_path = CELL_WEIGHTS_PATH
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
# Pick one image at random and load it together with its ground truth.
# (The former bare `dataset.image_ids` expression was a no-op statement;
# the image count is already printed when the dataset is loaded.)
# use_mini_mask=False presumably keeps GT masks at full image resolution so
# they line up with the predicted masks — confirm in modellib.load_image_gt.
image_id = random.choice(dataset.image_ids)
image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
    dataset, config, image_id, use_mini_mask=False)

info = dataset.image_info[image_id]
print("image ID: {}.{} ({}) {}".format(info["source"], info["id"], image_id,
                                       dataset.image_reference(image_id)))
# Run detection on the single sampled image; the first (only) entry of the
# returned list holds the predictions, looked up by key below.
results = model.detect([image], verbose=1)
r = results[0]

# Overlay predicted boxes/masks/classes/scores on the image.
ax = get_ax(1)
visualize.display_instances(
    image,
    r['rois'], r['masks'], r['class_ids'],
    dataset.class_names, r['scores'],
    ax=ax,
    title="Predictions",
)
# Report each ground-truth array for the sampled image via model.log
# (presumably a shape / value-range summary — confirm in `model`).
for _gt_name, _gt_value in (("gt_class_id", gt_class_id),
                            ("gt_bbox", gt_bbox),
                            ("gt_mask", gt_mask)):
    log(_gt_name, _gt_value)
# display_differences comes from the `mrcnn` package rather than the
# project-local `visualize` module imported at the top. NOTE(review):
# presumably the local module lacks this function — confirm.
from mrcnn import visualize as mrcnn_visualize

# Compare ground truth with predictions on a new axes: boxes only (no masks),
# with both the IoU and score thresholds set to 0.1.
diff_ax = get_ax()
mrcnn_visualize.display_differences(
    image,
    gt_bbox, gt_class_id, gt_mask,
    r['rois'], r['class_ids'], r['scores'], r['masks'],
    dataset.class_names,
    ax=diff_ax,
    show_box=True,
    show_mask=False,
    iou_threshold=0.1,
    score_threshold=0.1,
)
# Draw precision-recall curve.
# The last argument must be the *predicted* masks (r['masks']). Passing
# gt_mask there scores the ground truth against itself, which makes AP
# meaningless. It can also mis-index when the detection count differs from
# the GT instance count.
# NOTE(review): assumes utils.compute_ap has the Matterport signature
# (gt_boxes, gt_class_ids, gt_masks, pred_boxes, pred_class_ids, pred_scores,
#  pred_masks) — confirm against the local utils module.
AP, precisions, recalls, overlaps = utils.compute_ap(
    gt_bbox, gt_class_id, gt_mask,
    r['rois'], r['class_ids'], r['scores'], r['masks'])
visualize.plot_precision_recall(AP, precisions, recalls)